import numpy as np
import pylab as pl
import matplotlib.pyplot as plt

# # # # # # # # # # # # # # # # # # # # # # # # 
# This script plots the convergence points
# from the convergence experiments and fits
# exponential models to them. It reads an
# .npz file as generated by parse_log.py
# # # # # # # # # # # # # # # # # # # # # # # # 

# # # # # # # # # # # #
# # # P A R A M S # # #
# # # # # # # # # # # #

# Set filename
path = "data/"  # Directory holding the .npz files (keep the trailing slash).
fname = "TR31.npz"  # Default file, used when the prompt below is left empty.

# User-defined input. raw_input only exists on Python 2; on Python 3 the
# equivalent (returning the raw string) is input().
try:
    _read_line = raw_input
except NameError:
    _read_line = input
fname = _read_line("Enter filename [%s]: " % fname).strip() or fname

# Select device: its (W, B) coordinates, as shown in the plot title.
W, B = 21, 32

# colors for the repetitions (indexed by repetition position, so at most
# 6 repetitions can be plotted).
color = ('b', 'r', 'g', 'm', 'lime', '0.5')

# # # # # # # # # # # # # # # # # # #
# # # E N D   O F   P A R A M S # # #
# # # # # # # # # # # # # # # # # # #



pl.rc("font", size=8)

# Load data and add every stored array to the module namespace; the code
# below reads param, read_resist and write_resist this way.
# allow_pickle is required because `param` is stored as an object array.
# This unpickles the file's contents, so only load .npz files you trust.
# The `with` block closes the underlying zip file once everything is read.
with np.load(path + fname, allow_pickle=True) as X:
    for k, v in X.items():
        globals()[k] = v
# Presumably a 0-d object array wrapping a dict (it is indexed by key below).
param = param.tolist()

# pick device: boolean mask of the records measured on device (W, B)
dev_idx = (param['B'] == B) * (param['W'] == W)

# Mean and std of the read-pulse conductance (1/read_resist) of each trial,
# grouped by repetition (rows of gMeans/gStds) and Pr value (columns).
prlist = np.unique(param['Pr'])
replist = np.unique(param['Rep'])


def _conductance_stats(rep):
    """Return flattened (means, stds) of 1/read_resist for one repetition.

    Every record of the selected device with this repetition and a given
    Pr value contributes the mean/std over its own row of read pulses;
    the results are concatenated in prlist order.
    """
    in_rep = dev_idx * (param['Rep'] == rep)
    means, stds = [], []
    for pr in prlist:
        conductance = 1. / read_resist[in_rep * (param['Pr'] == pr)]
        means.append(conductance.mean(1))
        stds.append(conductance.std(1))
    return np.array(means).flatten(), np.array(stds).flatten()


_stats = [_conductance_stats(rep) for rep in replist]
gMeans = np.array([means for means, _ in _stats])
gStds = np.array([stds for _, stds in _stats])

#########################################
# Perform fitting of convergence points.#
#########################################

# Module imports.
from scipy.optimize import leastsq

# Key parameters: repetitions startRun..noConvs (inclusive) are used.
noConvs = 3  # Last repetition index included in the average.
startRun = 1  # First repetition index included; earlier runs are ignored.
runs = list(range(startRun, noConvs + 1))  # Rows of gMeans that are used.

# Average convergence point per Pr value over the selected repetitions.
# Averaging over len(runs) (instead of a hard-coded 3) keeps this correct
# when startRun/noConvs are changed.
data = gMeans[runs].mean(axis=0)

# x-axis of the averaged data: the distinct Pr values.
xdat = np.array(prlist)


def h(p, x):
    """Fitting model a*exp(-b*(x - c)) + d with p = [a, b, c, d]."""
    return p[0] * np.exp(-p[1] * (x - p[2])) + p[3]


def f(p, x, y):
    """Residual function: measured values y minus the model at x."""
    return y - h(p, x)


# Initial guess for [a, b, c, d].
fitguess = np.array([-0.000001, 1.0, 0.0, 0.000005])

# Least-squares fit. The data are passed explicitly instead of being read
# from the global `data`, which the next section rebinds.
fitting = leastsq(f, fitguess, args=(xdat, data))[0]

# x-axis for plotting the fitted curve: 101 points from 0 to 1.
fitx = np.linspace(0, 1, 101)

# Percentage deviation of every selected run from the fit:
# Errmat[i, k] belongs to Pr value prlist[i] and repetition runs[k].
model = h(fitting, xdat)[:, np.newaxis]
Errmat = 100.0 * (gMeans[runs].T - model) / model

Errmax = np.max(Errmat)
Errmin = np.min(Errmat)

# Display results.
print("Deviation of convergence points from the fit in percent")
print("(rows: Pr values, columns: runs %d..%d):" % (startRun, noConvs))
print(Errmat)
print("max = %g %%, min = %g %%" % (Errmax, Errmin))


#####################################################################################
# Fit convergence runs to exponentials, then extract extrapolated convergence point #
#####################################################################################

# One row of fitted [a, b, c, d] per record of write_resist, plus the Pr
# value of that record. Each record is fitted on its own length.
fitrun = np.zeros((len(write_resist), 4))
fitrunx = np.zeros(len(write_resist))


def hrun(p, x):
    """Fitting model a*exp(-b*(x - c)) + d with p = [a, b, c, d].

    a and c are not independent (a*exp(b*c) acts as one amplitude), so only
    their combination is determined by the fit.
    """
    return p[0] * np.exp(-p[1] * (x - p[2])) + p[3]


def frun(p, x, y):
    """Residual function: measured trace y minus the model at x."""
    return y - hrun(p, x)


# Initial guess for [a, b, c, d].
# NOTE(review): with a = 0 the model does not depend on b or c at the start
# point, so the first iteration can only adjust a and d.
fitguess = np.array([0.0, 1.0, 0.0, 0.000005])

# Sweep test runs.
# NOTE(review): unlike the section above, this fits every record, not only
# the device selected by (W, B) -- confirm that is intended.
for i, trace in enumerate(write_resist):
    # Conductance of every write event (1.0 avoids Python 2 integer division).
    data = 1.0 / np.asarray(trace)
    xdat = np.arange(len(trace))  # x-data: event indices 0..len-1.

    # Least-squares fit; the trace is passed explicitly to the residual.
    fitrun[i, :] = leastsq(frun, fitguess, args=(xdat, data))[0]
    fitrunx[i] = param['Pr'][i]
    # The constant offset d = p[3] is the value the model tends to for b > 0.
    print("Run %d extrapolated convergence to %g" % (i + 1, fitrun[i, 3]))

############
# plotting #
############
fig = pl.figure(figsize=(5, 3))
xs = np.array(prlist)  # x-axis for the convergence-point plots.

# One errorbar series per repetition (r is its position in replist).
for r, (ys, yerrs) in enumerate(zip(gMeans, gStds)):
    pl.errorbar(xs, ys, yerr=yerrs, fmt='o', c=color[r], label="Rep %d" % r)

# y-range: measured means with 10% padding. The fit curve and extrapolated
# points drawn below are not taken into account.
ymin, ymax = gMeans.min(), gMeans.max()
dy = ymax - ymin
pl.ylim(ymin - 0.1 * dy, ymax + 0.1 * dy)
pl.xlim(-0.05, 1.05)

pl.xlabel("p(y=1|z=1)")
pl.xticks(prlist)
pl.ylabel("Convergence conductance [S]")

pl.title("Device coords (W:%d, B:%d)" % (W, B), fontsize=8)
# Scientific notation for the y tick labels.
plt.ticklabel_format(style='sci', axis='y', scilimits=(0, 0))
# Figure margins: left, bottom, right, top (fractions of the figure).
fig.subplots_adjust(0.12, 0.15, 0.95, 0.9)

leg = pl.legend(loc="upper left")

# Fitted curve of the averaged convergence points.
# TODO: colour-code the fit to match the repetitions it was built from.
pl.plot(fitx, h(fitting, fitx), 'b', lw=2.)

# Extrapolated convergence points (offset d of each per-run fit), one row
# per repetition. reshape(..., -1) infers the row length and works on both
# Python 2 and 3 (a float shape from `/` breaks reshape on Python 3).
# NOTE(review): assumes write_resist records are ordered repetition-major
# (all Pr values of one Rep, then the next) -- confirm with parse_log.py.
extraconv = fitrun[:, 3].reshape(len(replist), -1)
fitrunxresh = fitrunx.reshape(len(replist), -1)

for i in range(len(extraconv)):
    pl.scatter(fitrunxresh[i], extraconv[i], s=60, marker='x', c=color[i])

pl.show()
